#coding:utf-8
# make dataset for height recognition
import os
from glob import glob
import numpy as np
from os.path import join
from ipdb import set_trace


# Directory holding the source images; each file name ends in '-<height>.jpg'.
src_dir = 'selected_rec'

# Fraction of the images that should land in the training list.
train_ratio = 0.6


def parse_height(filename):
    """Return the integer height encoded in an image file name.

    The height is the text between the last '-' in the name and the first
    '.' that follows it, e.g. 'person-0001-172.jpg' -> 172.

    Raises:
        ValueError: if that text is not a valid integer.
    """
    return int(filename.split('-')[-1].split('.')[0])


def split_lines(filenames, ratio):
    """Randomly assign each file name to the train or the test list.

    Every name independently goes to the train list with probability
    ``ratio`` (one ``np.random.rand()`` draw per name), otherwise to the
    test list.

    Args:
        filenames: iterable of image file names accepted by parse_height.
        ratio: probability, in [0, 1], of a name being put in the train list.

    Returns:
        A ``(train_lines, test_lines)`` tuple of lists; every element is a
        newline-terminated 'name,height' string.
    """
    train, test = [], []
    for name in filenames:
        line = '%s,%d\n' % (name, parse_height(name))
        # rand() is uniform on [0, 1), so `< ratio` holds with probability
        # `ratio`; the previous `>` gave the train list only 1 - ratio.
        if np.random.rand() < ratio:
            train.append(line)
        else:
            test.append(line)
    return train, test


# os.path.basename strips the directory with the platform's own separator.
files = [os.path.basename(path) for path in glob(join(src_dir, '*.jpg'))]

# Build both lists before touching the output files, so a malformed file
# name aborts the run without leaving truncated outputs behind.
train_lines, test_lines = split_lines(files, train_ratio)

# Context managers guarantee the lists are flushed and the handles closed.
with open('train_1216.txt', 'w') as train_f:
    train_f.writelines(train_lines)
with open('test_1216.txt', 'w') as test_f:
    test_f.writelines(test_lines)




